import java.util.List;

public class RemoveNthNodeFromEndOfList {
    /**
     * Removes the {@code n}-th node counted from the end of a singly linked list.
     *
     * <p>Two pointers are kept {@code n} nodes apart so the list is walked only once. When
     * the lead pointer runs off the end after the initial {@code n} steps, the node to
     * remove is the head itself. That case is handled explicitly so the new head
     * ({@code head.next}) is returned.
     *
     * @param head first node of the list
     * @param n 1-based position, counted from the tail, of the node to remove
     * @return the head of the modified list, or {@code null} if the list becomes empty
     * @throws IllegalArgumentException if {@code n < 1} or {@code n} exceeds the list length
     *     (including an empty list)
     */
    public ListNode removeNthFromEnd(ListNode head, int n) {
        if (n < 1) {
            throw new IllegalArgumentException("n must be >= 1, got " + n);
        }

        // Advance the lead pointer n nodes ahead of head.
        ListNode lead = head;
        for (int i = 0; i < n; i++) {
            if (lead == null) {
                throw new IllegalArgumentException("n=" + n + " exceeds list length " + i);
            }
            lead = lead.next;
        }

        // n equals the list length: the target is the head itself.
        if (lead == null) {
            ListNode newHead = head.next;
            head.next = null; // detach the removed node, as the original did
            return newHead;
        }

        // Move both pointers until lead sits on the tail; prev then precedes the target.
        ListNode prev = head;
        while (lead.next != null) {
            lead = lead.next;
            prev = prev.next;
        }

        ListNode removed = prev.next;
        prev.next = removed.next;
        removed.next = null; // detach the removed node, as the original did
        return head;
    }
}
